3.6.3 Network Pruning
We further prune the 1-bit CNNs to increase model efficiency and improve the flexibility of RBCNs in practical scenarios. This section describes the pruning optimization, covering both the loss function and the update rules for the learnable parameters.
3.6.3.1 Loss Function
After binarizing the CNNs, we prune the resulting 1-bit CNNs under the generative adversarial learning framework using the method described in [142]. We use a soft mask to remove the corresponding structures, such as filters, while keeping accuracy close to the baseline. A discriminator $D_p(\cdot)$ with weights $Y_p$ is introduced to distinguish the output $R_p$ of the baseline network from the output $T_p$ of the pruned 1-bit network. The pruned network, with weights $W_p$, $\hat{W}_p$, $C_p$ and a soft mask $M_p$, is learned together with $Y_p$ using knowledge of the supervised features of the baseline. $W_p$, $\hat{W}_p$, $C_p$, $M_p$, and $Y_p$ are learned by solving the following optimization problem:
$$
\arg\min_{W_p,\hat{W}_p,C_p,M_p}\max_{Y_p} L_p
= L_{Adv\_p}(W_p,\hat{W}_p,C_p,M_p,Y_p) + L_{Kernel\_p}(W_p,\hat{W}_p,C_p)
+ L_{S\_p}(W_p,\hat{W}_p,C_p) + L_{Data\_p}(W_p,\hat{W}_p,C_p,M_p) + L_{Reg\_p}(M_p,Y_p),
\tag{3.77}
$$
where $L_p$ is the pruning loss function, and the forms of $L_{Adv\_p}(W_p,\hat{W}_p,C_p,M_p,Y_p)$ and $L_{Kernel\_p}(W_p,\hat{W}_p,C_p)$ are

$$
L_{Adv\_p}(W_p,\hat{W}_p,C_p,M_p,Y_p) = \log(D_p(R_p;Y_p)) + \log(1 - D_p(T_p;Y_p)),
\tag{3.78}
$$

$$
L_{Kernel\_p}(W_p,\hat{W}_p,C_p) = \frac{\lambda_1}{2}\big\|W_p - C_p\hat{W}_p\big\|^2.
\tag{3.79}
$$
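To make these two terms concrete, the following is a minimal PyTorch sketch of Eqs. 3.78 and 3.79. The discriminator `D_p` (assumed here to output probabilities in $(0,1)$), the feature tensors `R_p` and `T_p`, and the broadcastable shapes of `W_p`, `W_hat_p`, and `C_p` are illustrative assumptions, not the reference implementation.

```python
import torch

def adversarial_loss(D_p, R_p, T_p):
    # Eq. 3.78: L_Adv_p = log(D_p(R_p)) + log(1 - D_p(T_p)).
    # The discriminator maximizes this; the pruned network minimizes it by
    # making its features T_p indistinguishable from the baseline's R_p.
    eps = 1e-8  # guards the logarithms against log(0)
    return (torch.log(D_p(R_p) + eps).mean()
            + torch.log(1.0 - D_p(T_p) + eps).mean())

def kernel_loss(W_p, W_hat_p, C_p, lambda1):
    # Eq. 3.79: L_Kernel_p = (lambda_1 / 2) * ||W_p - C_p * W_hat_p||^2,
    # keeping the scaled binary kernels close to the full-precision ones.
    return 0.5 * lambda1 * (W_p - C_p * W_hat_p).pow(2).sum()
```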
$L_{S\_p}$ is a traditional problem-dependent loss, such as the softmax loss. $L_{Data\_p}$ is the data loss between the output features of the baseline and the pruned network, used to align the outputs of the two networks. It is expressed as an MSE loss:

$$
L_{Data\_p}(W_p,\hat{W}_p,C_p,M_p) = \frac{1}{2n}\|R_p - T_p\|^2,
\tag{3.80}
$$

where $n$ is the size of the mini-batch.
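In code, Eq. 3.80 is a plain mean-squared-error term; a one-function sketch follows (reading the mini-batch size $n$ off the leading tensor dimension is an assumption):

```python
def data_loss(R_p, T_p):
    # Eq. 3.80: L_Data_p = ||R_p - T_p||^2 / (2n), aligning the pruned
    # network's output features T_p with the baseline features R_p.
    n = R_p.size(0)  # mini-batch size
    return (R_p - T_p).pow(2).sum() / (2.0 * n)
```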
$L_{Reg\_p}(M_p,Y_p)$ is a regularizer on $M_p$ and $Y_p$, which can be split into two parts as follows:

$$
L_{Reg\_p}(M_p,Y_p) = R_\lambda(M_p) + R(Y_p),
\tag{3.81}
$$

where $R(Y_p) = \log(D_p(T_p;Y_p))$, and $R_\lambda(M_p) = \lambda\|M_p\|_{\ell_1}$ is a sparsity regularizer with weighting parameter $\lambda$.
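The sparsity part $R_\lambda(M_p)$ is an $\ell_1$ penalty on the soft mask, which drives mask entries toward zero so the corresponding filters can be removed. A one-line sketch (the flattened layout of `M_p` is an assumption):

```python
def mask_regularizer(M_p, lam):
    # R_lambda(M_p) = lambda * ||M_p||_1 (Eq. 3.81): the l1 penalty pushes
    # soft-mask entries toward zero, marking their filters for pruning.
    return lam * M_p.abs().sum()
```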
As in the binarization process, the update of the discriminators is omitted in the following description until Algorithm 2. We also omit $\log(\cdot)$ for simplicity and rewrite the optimization of Eq. 3.77 as
$$
\min_{W_p,\hat{W}_p,C_p,M_p}\;
\frac{\lambda_1}{2}\sum_{l}\sum_{i}\big\|W^l_{p,i} - C^l\hat{W}^l_{p,i}\big\|^2
+ \sum_{l}\sum_{i}\big\|1 - D(T^l_{p,i};Y)\big\|^2
+ L_{S\_p}(W_p,\hat{W}_p,C_p)
+ \frac{1}{2n}\|R_p - T_p\|^2
+ \lambda\|M_p\|_{\ell_1}.
\tag{3.82}
$$
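Putting the pieces together, below is a hedged sketch of one minimization step of Eq. 3.82 for the pruned network, with the discriminator update omitted as in the text. The per-layer sums over $l$ and $i$ are collapsed to a single layer for brevity, `R_p` is assumed to come from a frozen baseline, and all names are illustrative.

```python
def pruning_step(D_p, R_p, T_p, W_p, W_hat_p, C_p, M_p,
                 task_loss, lambda1, lam, optimizer):
    # One generator-side update for Eq. 3.82. With log(.) dropped, the
    # adversarial term becomes ||1 - D_p(T_p)||^2.
    R_p = R_p.detach()  # baseline features carry no gradient
    n = R_p.size(0)
    loss = (0.5 * lambda1 * (W_p - C_p * W_hat_p).pow(2).sum()  # kernel loss
            + (1.0 - D_p(T_p)).pow(2).sum()                     # adversarial
            + task_loss                                         # L_S_p
            + (R_p - T_p).pow(2).sum() / (2.0 * n)              # data loss
            + lam * M_p.abs().sum())                            # l1 on mask
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```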
3.6.3.2 Learning Pruned RBCNs
In pruned RBCNs, the convolution is implemented as

$$
F^l_{out,p} = \mathrm{RBConv}\big(F^l_{in,p};\,\hat{W}^l_p \circ M^l_p,\,C^l_p\big)
= \mathrm{Conv}\big(F^l_{in,p},\,(\hat{W}^l_p \circ M^l_p) \odot C^l_p\big),
\tag{3.83}
$$
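A minimal PyTorch sketch of this masked, scaled binary convolution is given below. Treating both `M_p` and `C_p` as per-output-filter vectors, and using a generic sign-with-straight-through binarizer in place of the RBCN-specific one, are assumptions.

```python
import torch
import torch.nn.functional as F

def binarize_ste(w):
    # sign() with a straight-through estimator so gradients reach w.
    return (torch.sign(w) - w).detach() + w

def pruned_rbconv(x, w_real, M_p, C_p, stride=1, padding=1):
    # Eq. 3.83: Conv(F_in, (W_hat o M) * C). The soft mask M_p zeroes pruned
    # output filters; the learnable scale C_p restores the magnitude lost
    # when the real-valued kernels are binarized.
    w_hat = binarize_ste(w_real)  # W_hat: 1-bit kernels in {-1, +1}
    w = w_hat * M_p.view(-1, 1, 1, 1) * C_p.view(-1, 1, 1, 1)
    return F.conv2d(x, w, stride=stride, padding=padding)
```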